
Adding Torch support for Model Parallel #22394

Draft

buildwithsuhana wants to merge 15 commits into keras-team:master from buildwithsuhana:mp_parallelize

Conversation

@buildwithsuhana Collaborator

This PR introduces torch backend support for Model Parallelism (MP) in Keras. It aligns the internal distribution_lib implementations so that the high-level Keras Distribution APIs (DeviceMesh, LayoutMap, and ModelParallel) behave consistently regardless of the underlying framework, and it leverages PyTorch DTensor and DeviceMesh to handle sharding and replication.

Design document: go/distributionLib

Kaggle notebook testing model parallelism on the torch and jax backends (using the keras_hub OPT model): https://www.kaggle.com/code/buildwithsuhana/dtensor-model-parallel-data-parallel-for-torch
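For context, here is a minimal sketch of how the existing high-level Keras distribution API is typically driven; with this PR, the same code would also work on the torch backend. The mesh shape and the `dense.*kernel` / `dense.*bias` layout patterns are illustrative, not taken from this PR:

```python
import os

# Select the torch backend before importing keras.
os.environ["KERAS_BACKEND"] = "torch"

import keras

# Build a 2D mesh over all visible devices: one "batch" axis, one "model" axis.
devices = keras.distribution.list_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)

# Map variable paths (regexes) to layouts: shard dense kernels column-wise
# along the "model" axis; unmatched variables are replicated by default.
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")
layout_map["dense.*bias"] = ("model",)

distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
# Any model built after this point gets its variables laid out per layout_map.
```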

@gemini-code-assist Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances Keras by adding support for model parallelism with the Torch backend. It introduces necessary components for distributing tensors and variables across multiple devices, enabling larger models to be trained efficiently. The changes include modifications to core tensor operations to be aware of sharding and replication, as well as utilities for initializing and managing distributed environments.

Highlights

  • Model Parallelism Support: Introduces torch backend support for Model Parallelism (MP) in Keras, leveraging PyTorch DTensor and DeviceMesh for sharding and replication.
  • Distribution Library Alignment: Aligns internal distribution_lib implementations to ensure consistent behavior of high-level Keras Distribution APIs (DeviceMesh, LayoutMap, ModelParallel) across different frameworks.
  • Sharding-Aware Operations: Implements sharding-aware overrides for torch tensor operations, including getitem, unbind, broadcast_to, einsum, and detach, so that they handle distributed tensors correctly (see the DTensor sketch after this list).
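As background on that last point, a minimal sketch of the PyTorch primitives the PR builds on, assuming torch >= 2.4 (where DTensor is public under torch.distributed.tensor) and execution under torchrun with two processes:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# 1D mesh of 2 devices along a "model" axis; init_device_mesh sets up the
# default process group from the torchrun environment variables.
mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("model",))

weight = torch.randn(128, 256)
# Shard the weight along dim 1: each rank holds a (128, 128) local slice.
weight_d = distribute_tensor(weight, mesh, placements=[Shard(1)])

bias = torch.zeros(256)
# Replicate the bias: every rank holds the full (256,) tensor.
bias_d = distribute_tensor(bias, mesh, placements=[Replicate()])

# Ops such as einsum, broadcast_to, and detach must respect these placements,
# which is what the sharding-aware overrides in this PR provide on the Keras side.
```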



@gemini-code-assist bot left a comment
Code Review

This pull request introduces support for model parallelism with the PyTorch backend. The changes are extensive, touching core backend components, distribution libraries, and some ops. The overall approach of using PyTorch's DTensor and parallelize_module is sound. I've identified a few areas for improvement, mainly concerning code structure and backend abstractions. Specifically, some backend-specific logic has been added to the generic keras.ops module, which should be moved to the torch-specific backend implementation to maintain clean separation. I also have a suggestion to improve code clarity in the Variable class implementation.
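A hypothetical sketch of the separation the review asks for (the op name and relu body are illustrative, not the PR's actual code): the generic keras.src.ops wrapper simply calls backend.nn.some_op(x) with no torch imports, while all DTensor awareness lives in the torch backend module:

```python
# Hypothetical torch-backend implementation (e.g. keras/src/backend/torch/nn.py).
# Assumes torch >= 2.4 for the public torch.distributed.tensor module.
import torch
from torch.distributed.tensor import DTensor, Replicate

def some_op(x):
    if isinstance(x, DTensor):
        # DTensor-aware path: fall back to a replicated layout for ops that
        # lack a sharding rule, trading communication for correctness.
        x = x.redistribute(placements=[Replicate()] * x.device_mesh.ndim)
    return torch.nn.functional.relu(x)
```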


codecov-commenter commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 37.56757% with 231 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.75%. Comparing base (e4834f6) to head (02dba75).
⚠️ Report is 4 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/torch/distribution_lib.py 29.44% 110 Missing and 17 partials ⚠️
keras/src/backend/torch/core.py 54.96% 49 Missing and 19 partials ⚠️
keras/src/ops/nn.py 0.00% 32 Missing and 1 partial ⚠️
keras/src/backend/torch/layer.py 40.00% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22394      +/-   ##
==========================================
- Coverage   82.99%   82.75%   -0.25%     
==========================================
  Files         596      597       +1     
  Lines       66423    66945     +522     
  Branches    10353    10461     +108     
==========================================
+ Hits        55130    55400     +270     
- Misses       8665     8870     +205     
- Partials     2628     2675      +47     
Flag Coverage Δ
keras 82.58% <37.56%> (-0.25%) ⬇️
keras-jax 60.40% <19.72%> (-0.34%) ⬇️
keras-numpy 54.66% <20.27%> (-0.29%) ⬇️
keras-openvino 49.26% <20.27%> (-0.02%) ⬇️
keras-tensorflow 61.63% <20.27%> (-0.34%) ⬇️
keras-torch 60.56% <37.56%> (-0.25%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
@buildwithsuhana marked this pull request as draft on March 12, 2026 at 02:59